from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import CONFIG_MAPPING

logger = logging.get_logger(__name__)

class AttrDict:
    def __init__(self, entries):
        for key, value in entries.items():
            if isinstance(value, dict):
                entries[key] = AttrDict(value)
        self.__dict__.update(entries)

    def get(self, key, default_val=None):
        return self.__dict__.get(key, default_val)

class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))

class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))

class UrsaConfig(PretrainedConfig):
    model_type = "ursa"
    is_composition = False
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    text_config: PretrainedConfig

    def __init__(
        self,
        vision_config=None,
        aligner_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'."
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if vision_config is None:
            vision_config = VisionConfig()
            vision_config.cls = "HybridVisionTower"
            vision_config.params = {
                            "concat_type": "tuple",
                            "high_res_cfg": {
                                "ckpt_path": "",
                                "image_size": 1024,
                                "model_name": "sam_b_downsample",
                                "output_dim": 1024,
                                "pixel_mean": [
                                0.48145466,
                                0.4578275,
                                0.40821073
                                ],
                                "pixel_std": [
                                0.26862954,
                                0.26130258,
                                0.27577711
                                ],
                                "select_feature": "same",
                                "select_layer": -1
                            },
                            "low_res_cfg": {
                                "ckpt_path": "",
                                "image_size": 384,
                                "model_name": "siglip_large_patch16_384",
                                "output_dim": 1024,
                                "pixel_mean": [
                                0.5,
                                0.5,
                                0.5
                                ],
                                "pixel_std": [
                                0.5,
                                0.5,
                                0.5
                                ],
                                "select_feature": "same",
                                "select_layer": -1
                            }
                        }
        self.vision_config = vision_config
        self.aligner_config = aligner_config
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        super().__init__(**kwargs)